{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7yuytuIllsv1"
      },
      "source": [
        "# Trax Layers Intro\n",
        "\n",
        "This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:\n",
        "\n",
        "  1. **Layers**: the basic building blocks and how to combine them\n",
        "  1. **Inputs and Outputs**: how data streams flow through layers\n",
        "  1. **Defining New Layer Classes** (if combining existing layers isn't enough)\n",
        "  1. **Testing and Debugging Layer Classes**\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BIl27504La0G"
      },
      "source": [
        "**General Setup**\n",
        "\n",
        "Execute the following few cells (once) before running any of the code samples in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "oILRLCWN_16u"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 Google LLC.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {
          "height": 51
        },
        "colab_type": "code",
        "id": "vlGjGoGMTt-D",
        "outputId": "76b95a37-3f1b-4748-bef0-646858f33e25"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
          ]
        }
      ],
      "source": [
        "# Import Trax\n",
        "\n",
        "! pip install -q -U trax\n",
        "! pip install -q tensorflow\n",
        "\n",
        "from trax import fastmath\n",
        "from trax import layers as tl\n",
        "from trax import shapes\n",
        "from trax.fastmath import numpy as jnp  # For use in defining new layer types.\n",
        "from trax.shapes import ShapeDtype\n",
        "from trax.shapes import signature"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "bYWNWL9MJHv9"
      },
      "outputs": [],
      "source": [
        "# Settings and utilities for handling inputs, outputs, and object properties.\n",
        "\n",
        "np.set_printoptions(precision=3)  # Reduce visual noise from extra digits.\n",
        "\n",
        "def show_layer_properties(layer_obj, layer_name):\n",
        "  template = ('{}.n_in:  {}\\n'\n",
        "              '{}.n_out: {}\\n'\n",
        "              '{}.sublayers: {}\\n'\n",
        "              '{}.weights:    {}\\n')\n",
        "  print(template.format(layer_name, layer_obj.n_in,\n",
        "                        layer_name, layer_obj.n_out,\n",
        "                        layer_name, layer_obj.sublayers,\n",
        "                        layer_name, layer_obj.weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-LQ89rFFsEdk"
      },
      "source": [
        "## 1. Layers\n",
        "\n",
        "The Layer class represents Trax's basic building blocks:\n",
        "```\n",
        "class Layer:\n",
        "  \"\"\"Base class for composable layers in a deep learning network.\n",
        "\n",
        "  Layers are the basic building blocks for deep learning models. A Trax layer\n",
        "  computes a function from zero or more inputs to zero or more outputs,\n",
        "  optionally using trainable weights (common) and non-parameter state (not\n",
        "  common).  ...\n",
        "\n",
        "  ...\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LyLVtdxorDPO"
      },
      "source": [
        "### Layers compute functions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ntZ4_eNQldzL"
      },
      "source": [
        "A layer computes a function from zero or more inputs to zero or more outputs.\n",
        "The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.\n",
        "\n",
        "The simplest layers, those with no weights or sublayers, can be used without\n",
        "initialization. You can think of them as (pure) mathematical functions that can\n",
        "be plugged into neural networks.\n",
        "\n",
        "For ease of testing and interactive exploration, layer objects implement the\n",
        "`__call__` method, so you can call them directly on input data:\n",
        "```\n",
        "y = my_layer(x)\n",
        "```\n",
        "\n",
        "Layers are also objects, so you can inspect their properties. For example:\n",
        "```\n",
        "print(f'Number of inputs expected by this layer: {my_layer.n_in}')\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hCoapc5le8B7"
      },
      "source": [
        "**Example 1.** tl.Relu $[n_{in} = 1, n_{out} = 1]$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 224
        },
        "colab_type": "code",
        "id": "V09viOSEQvQe",
        "outputId": "a0134cee-0db8-4396-825e-93e695a42ca5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[ -2  -1   0   1   2]\n",
            " [-20 -10   0  10  20]]\n",
            "\n",
            "relu(x):\n",
            "[[ 0  0  0  1  2]\n",
            " [ 0  0  0 10 20]]\n",
            "\n",
            "Number of inputs expected by this layer: 1\n",
            "Number of outputs promised by this layer: 1\n"
          ]
        }
      ],
      "source": [
        "relu = tl.Relu()\n",
        "\n",
        "x = np.array([[-2, -1, 0, 1, 2],\n",
        "              [-20, -10, 0, 10, 20]])\n",
        "y = relu(x)\n",
        "\n",
        "# Show input, output, and two layer properties.\n",
        "print(f'x:\\n{x}\\n\\n'\n",
        "      f'relu(x):\\n{y}\\n\\n'\n",
        "      f'Number of inputs expected by this layer: {relu.n_in}\\n'\n",
        "      f'Number of outputs promised by this layer: {relu.n_out}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7sYxIT8crFVE"
      },
      "source": [
        "**Example 2.** tl.Concatenate $[n_{in} = 2, n_{out} = 1]$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 255
        },
        "colab_type": "code",
        "id": "LMPPNWXLoOZI",
        "outputId": "42f595b1-4014-429a-a0b3-2c12d630cd32"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x0:\n",
            "[[1 2 3]\n",
            " [4 5 6]]\n",
            "\n",
            "x1:\n",
            "[[10 20 30]\n",
            " [40 50 60]]\n",
            "\n",
            "concat([x0, x1]):\n",
            "[[ 1  2  3 10 20 30]\n",
            " [ 4  5  6 40 50 60]]\n",
            "\n",
            "Number of inputs expected by this layer: 2\n",
            "Number of outputs promised by this layer: 1\n"
          ]
        }
      ],
      "source": [
        "concat = tl.Concatenate()\n",
        "\n",
        "x0 = np.array([[1, 2, 3],\n",
        "               [4, 5, 6]])\n",
        "x1 = np.array([[10, 20, 30],\n",
        "               [40, 50, 60]])\n",
        "y = concat([x0, x1])\n",
        "\n",
        "print(f'x0:\\n{x0}\\n\\n'\n",
        "      f'x1:\\n{x1}\\n\\n'\n",
        "      f'concat([x0, x1]):\\n{y}\\n\\n'\n",
        "      f'Number of inputs expected by this layer: {concat.n_in}\\n'\n",
        "      f'Number of outputs promised by this layer: {concat.n_out}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "z7N1qe91eYyM"
      },
      "source": [
        "### Layers are configurable.\n",
        "\n",
        "Many layer types have creation-time parameters for flexibility. The \n",
        "`Concatenate` layer type, for instance, has two optional parameters:\n",
        "\n",
        "*   `axis`: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.\n",
        "*   `n_items`: number of tensors to join into one by concatenation; default value is 2.\n",
        "\n",
        "The following example shows `Concatenate` configured for **3** input tensors,\n",
        "and concatenation along the initial $(0^{th})$ axis."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "l53Jw23pZ4s6"
      },
      "source": [
        "**Example 3.** tl.Concatenate(n_items=3, axis=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 340
        },
        "colab_type": "code",
        "id": "bhhWlVLffZtf",
        "outputId": "5a8afaa1-66c8-47fe-abcc-e7cfa33bb28c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x0:\n",
            "[[1 2 3]\n",
            " [4 5 6]]\n",
            "\n",
            "x1:\n",
            "[[10 20 30]\n",
            " [40 50 60]]\n",
            "\n",
            "x2:\n",
            "[[100 200 300]\n",
            " [400 500 600]]\n",
            "\n",
            "concat3([x0, x1, x2]):\n",
            "[[  1   2   3]\n",
            " [  4   5   6]\n",
            " [ 10  20  30]\n",
            " [ 40  50  60]\n",
            " [100 200 300]\n",
            " [400 500 600]]\n"
          ]
        }
      ],
      "source": [
        "concat3 = tl.Concatenate(n_items=3, axis=0)\n",
        "\n",
        "x0 = np.array([[1, 2, 3],\n",
        "               [4, 5, 6]])\n",
        "x1 = np.array([[10, 20, 30],\n",
        "               [40, 50, 60]])\n",
        "x2 = np.array([[100, 200, 300],\n",
        "               [400, 500, 600]])\n",
        "\n",
        "y = concat3([x0, x1, x2])\n",
        "\n",
        "print(f'x0:\\n{x0}\\n\\n'\n",
        "      f'x1:\\n{x1}\\n\\n'\n",
        "      f'x2:\\n{x2}\\n\\n'\n",
        "      f'concat3([x0, x1, x2]):\\n{y}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1oZv3R8bRMvF"
      },
      "source": [
        "### Layers are trainable.\n",
        "\n",
        "Many layer types include weights that affect the computation of outputs from\n",
        "inputs, and they use back-propagated gradients to update those weights.\n",
        "\n",
        "🚧🚧 *A very small subset of layer types, such as `BatchNorm`, also include\n",
        "modifiable weights (called `state`) that are updated based on forward-pass\n",
        "inputs/computation rather than back-propagated gradients.*"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3d64M7wLryji"
      },
      "source": [
        "**Initialization**\n",
        "\n",
        "Trainable layers must be initialized before use. Trax can take care of this\n",
        "as part of the overall training process. In other settings (e.g., in tests or\n",
        "interactively in a Colab notebook), you need to initialize the\n",
        "*outermost/topmost* layer explicitly. For this, use `init`:\n",
        "\n",
        "```\n",
        "  def init(self, input_signature, rng=None, use_cache=False):\n",
        "    \"\"\"Initializes weights/state of this layer and its sublayers recursively.\n",
        "\n",
        "    Initialization creates layer weights and state, for layers that use them.\n",
        "    It derives the necessary array shapes and data types from the layer's input\n",
        "    signature, which is itself just shape and data type information.\n",
        "\n",
        "    For layers without weights or state, this method safely does nothing.\n",
        "\n",
        "    This method is designed to create weights/state only once for each layer\n",
        "    instance, even if the same layer instance occurs in multiple places in the\n",
        "    network. This enables weight sharing to be implemented as layer sharing.\n",
        "\n",
        "    Args:\n",
        "      input_signature: `ShapeDtype` instance (if this layer takes one input)\n",
        "          or list/tuple of `ShapeDtype` instances.\n",
        "      rng: Single-use random number generator (JAX PRNG key), or `None`;\n",
        "          if `None`, use a default computed from an integer 0 seed.\n",
        "      use_cache: If `True`, and if this layer instance has already been\n",
        "          initialized elsewhere in the network, then return special marker\n",
        "          values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.\n",
        "          Else return this layer's newly initialized weights and state.\n",
        "\n",
        "    Returns:\n",
        "      A `(weights, state)` tuple.\n",
        "    \"\"\"\n",
        "```\n",
        "\n",
        "Input signatures can be built from scratch using `ShapeDtype` objects, or can\n",
        "be derived from data via the `signature` function (in module `shapes`):\n",
        "```\n",
        "def signature(obj):\n",
        "  \"\"\"Returns a `ShapeDtype` signature for the given `obj`.\n",
        "\n",
        "  A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`\n",
        "  instances. Note that this function is permissive with respect to its inputs\n",
        "  (accepts lists or tuples or dicts, and underlying objects can be any type\n",
        "  as long as they have shape and dtype attributes) and returns the corresponding\n",
        "  nested structure of `ShapeDtype`.\n",
        "\n",
        "  Args:\n",
        "    obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict\n",
        "        of such objects.\n",
        "\n",
        "  Returns:\n",
        "    A corresponding nested structure of `ShapeDtype` instances.\n",
        "  \"\"\"\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yL8HAj6GEAp1"
      },
      "source": [
        "**Example 4.** tl.LayerNorm $[n_{in} = 1, n_{out} = 1]$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 221
        },
        "colab_type": "code",
        "id": "Ie7iyX91qAx2",
        "outputId": "0efecdf5-c0a4-4304-f442-d12fc1a51253"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[-2. -1.  0.  1.  2.]\n",
            " [ 1.  2.  3.  4.  5.]\n",
            " [10. 20. 30. 40. 50.]]\n",
            "\n",
            "layer_norm(x):\n",
            "[[-1.414 -0.707  0.     0.707  1.414]\n",
            " [-1.414 -0.707  0.     0.707  1.414]\n",
            " [-1.414 -0.707  0.     0.707  1.414]]\n",
            "\n",
            "layer_norm.weights:\n",
            "(DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32))\n"
          ]
        }
      ],
      "source": [
        "layer_norm = tl.LayerNorm()\n",
        "\n",
        "x = np.array([[-2, -1, 0, 1, 2],\n",
        "              [1, 2, 3, 4, 5],\n",
        "              [10, 20, 30, 40, 50]]).astype(np.float32)\n",
        "layer_norm.init(shapes.signature(x))\n",
        "\n",
        "y = layer_norm(x)\n",
        "\n",
        "print(f'x:\\n{x}\\n\\n'\n",
        "      f'layer_norm(x):\\n{y}\\n')\n",
        "print(f'layer_norm.weights:\\n{layer_norm.weights}')"
      ]
    },
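    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The normalized values in Example 4 can be checked by hand. The following is a\n",
        "minimal NumPy sketch, not Trax's implementation: it assumes that layer\n",
        "normalization standardizes each row over its last axis and then applies a\n",
        "per-feature scale and bias, which is consistent with the initial weights\n",
        "printed above (ones and zeros). The name `layer_norm_sketch` and the `eps`\n",
        "value are illustrative assumptions.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def layer_norm_sketch(x, scale, bias, eps=1e-6):\n",
        "  # Standardize each row over its last axis, then scale and shift.\n",
        "  mean = np.mean(x, axis=-1, keepdims=True)\n",
        "  var = np.mean((x - mean) ** 2, axis=-1, keepdims=True)\n",
        "  return (x - mean) / np.sqrt(var + eps) * scale + bias\n",
        "\n",
        "x = np.array([[-2, -1, 0, 1, 2],\n",
        "              [1, 2, 3, 4, 5],\n",
        "              [10, 20, 30, 40, 50]], dtype=np.float32)\n",
        "scale = np.ones(5, dtype=np.float32)  # Assumed to mirror the first weight.\n",
        "bias = np.zeros(5, dtype=np.float32)  # Assumed to mirror the second weight.\n",
        "\n",
        "y = layer_norm_sketch(x, scale, bias)\n",
        "print(y)\n",
        "```\n",
        "\n",
        "Each row has the same relative spacing around its mean, so all three rows\n",
        "normalize to the same values."
      ]
    },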
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "d47gVdGV1vWw"
      },
      "source": [
        "### Layers combine into layers.\n",
        "\n",
        "The Trax library authors encourage users to build networks and network\n",
        "components as combinations of existing layers, by means of a small set of\n",
        "_combinator_ layers. A combinator makes a list of layers behave as a single\n",
        "layer -- by combining the sublayer computations yet looking from the outside\n",
        "like any other layer. The combined layer, like other layers, can:\n",
        "\n",
        "* compute outputs from inputs,\n",
        "* update parameters from gradients, and\n",
        "* combine with yet more layers."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vC1ymG2j0iyp"
      },
      "source": [
        "**Combine with `Serial`**\n",
        "\n",
        "The most common way to combine layers is with the `Serial` combinator:\n",
        "```\n",
        "class Serial(base.Layer):\n",
        "  \"\"\"Combinator that applies layers serially (by function composition).\n",
        "\n",
        "  This combinator is commonly used to construct deep networks, e.g., like this::\n",
        "\n",
        "      mlp = tl.Serial(\n",
        "        tl.Dense(128),\n",
        "        tl.Relu(),\n",
        "        tl.Dense(10),\n",
        "        tl.LogSoftmax()\n",
        "      )\n",
        "\n",
        "  A Serial combinator uses stack semantics to manage data for its sublayers.\n",
        "  Each sublayer sees only the inputs it needs and returns only the outputs it\n",
        "  has generated. The sublayers interact via the data stack. For instance, a\n",
        "  sublayer k, following sublayer j, gets called with the data stack in the\n",
        "  state left after layer j has applied. The Serial combinator then:\n",
        "\n",
        "    - takes n_in items off the top of the stack (n_in = k.n_in) and calls\n",
        "      layer k, passing those items as arguments; and\n",
        "\n",
        "    - takes layer k's n_out return values (n_out = k.n_out) and pushes\n",
        "      them onto the data stack.\n",
        "\n",
        "  A Serial instance with no sublayers acts as a special-case (but useful)\n",
        "  1-input 1-output no-op.\n",
        "  \"\"\"\n",
        "```\n",
        "If one layer has the same number of outputs as the next layer has inputs (which\n",
        "is the usual case), the successive layers behave like function composition:\n",
        "\n",
        "```\n",
        "#  h(.) = g(f(.))\n",
        "layer_h = Serial(\n",
        "    layer_f,\n",
        "    layer_g,\n",
        ")\n",
        "```\n",
        "Note how, inside `Serial`, function composition is expressed naturally as a\n",
        "succession of operations, so that no nested parentheses are needed.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "uPOnrDa9ViPi"
      },
      "source": [
        "**Example 5.** y = layer_norm(relu(x)) $[n_{in} = 1, n_{out} = 1]$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 136
        },
        "colab_type": "code",
        "id": "dW5fpusjvjmh",
        "outputId": "acdcffe7-23d5-4ecd-df9b-32f48ae77959"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[ -2.  -1.   0.   1.   2.]\n",
            " [-20. -10.   0.  10.  20.]]\n",
            "\n",
            "layer_block(x):\n",
            "[[-0.75 -0.75 -0.75  0.5   1.75]\n",
            " [-0.75 -0.75 -0.75  0.5   1.75]]\n"
          ]
        }
      ],
      "source": [
        "layer_block = tl.Serial(\n",
        "    tl.Relu(),\n",
        "    tl.LayerNorm(),\n",
        ")\n",
        "\n",
        "x = np.array([[-2, -1, 0, 1, 2],\n",
        "              [-20, -10, 0, 10, 20]]).astype(np.float32)\n",
        "layer_block.init(shapes.signature(x))\n",
        "y = layer_block(x)\n",
        "\n",
        "print(f'x:\\n{x}\\n\\n'\n",
        "      f'layer_block(x):\\n{y}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bRtmN6ckQO1q"
      },
      "source": [
        "And we can inspect the block as a whole, as if it were just another layer:\n",
        "\n",
        "**Example 5'.** Inspecting a `Serial` layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 68
        },
        "colab_type": "code",
        "id": "D6BpYddZQ1eu",
        "outputId": "1a00c9f2-63a0-450c-d902-c9baf06dc917"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "layer_block: Serial[\n",
            "  Relu\n",
            "  LayerNorm\n",
            "]\n",
            "\n",
            "layer_block.weights: ((), (DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32)))\n"
          ]
        }
      ],
      "source": [
        "print(f'layer_block: {layer_block}\\n\\n'\n",
        "      f'layer_block.weights: {layer_block.weights}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kJ8bpYZtE66x"
      },
      "source": [
        "**Combine with `Branch`**\n",
        "\n",
        "The `Branch` combinator arranges layers into parallel computational channels:\n",
        "```\n",
        "def Branch(*layers, name='Branch'):\n",
        "  \"\"\"Combinator that applies a list of layers in parallel to copies of inputs.\n",
        "\n",
        "  Each layer in the input list is applied to as many inputs from the stack\n",
        "  as it needs, and their outputs are successively combined on stack.\n",
        "\n",
        "  For example, suppose one has three layers:\n",
        "\n",
        "    - F: 1 input, 1 output\n",
        "    - G: 3 inputs, 1 output\n",
        "    - H: 2 inputs, 2 outputs (h1, h2)\n",
        "\n",
        "  Then Branch(F, G, H) will take 3 inputs and give 4 outputs:\n",
        "\n",
        "    - inputs: a, b, c\n",
        "    - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)\n",
        "\n",
        "  As an important special case, a None argument to Branch acts as if it takes\n",
        "  one argument, which it leaves unchanged. (It acts as a one-arg no-op.)\n",
        "\n",
        "  Args:\n",
        "    *layers: List of layers.\n",
        "    name: Descriptive name for this layer.\n",
        "\n",
        "  Returns:\n",
        "    A branch layer built from the given sublayers.\n",
        "  \"\"\"\n",
        "```"
      ]
    },
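    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The F, G, H example from the docstring can be traced with a small pure-Python\n",
        "sketch. This is not how Trax implements `Branch`; it only mimics the described\n",
        "behavior, in which every sublayer reads its inputs from the same stack top and\n",
        "the outputs are collected in order. The helper `branch_sketch` and the bodies\n",
        "of `F`, `G`, and `H` are hypothetical.\n",
        "\n",
        "```python\n",
        "def branch_sketch(layers, inputs):\n",
        "  # layers: list of (n_in, fn) pairs; each fn returns a tuple of outputs.\n",
        "  # Every layer reads its first n_in inputs from the same input list.\n",
        "  outputs = []\n",
        "  for n_in, fn in layers:\n",
        "    outputs.extend(fn(*inputs[:n_in]))\n",
        "  return tuple(outputs)\n",
        "\n",
        "F = (1, lambda a: (a + 1,))            # 1 input, 1 output\n",
        "G = (3, lambda a, b, c: (a + b + c,))  # 3 inputs, 1 output\n",
        "H = (2, lambda a, b: (a * b, a - b))   # 2 inputs, 2 outputs\n",
        "\n",
        "# Three inputs (a, b, c) yield four outputs: F(a), G(a, b, c), h1, h2.\n",
        "y = branch_sketch([F, G, H], [1, 2, 3])\n",
        "print(y)\n",
        "```"
      ]
    },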
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RlPcnRtdIVgq"
      },
      "source": [
        "Residual blocks, for example, are implemented using `Branch`:\n",
        "```\n",
        "def Residual(*layers, shortcut=None):\n",
        "  \"\"\"Wraps a series of layers with a residual connection.\n",
        "\n",
        "  Args:\n",
        "    *layers: One or more layers, to be applied in series.\n",
        "    shortcut: If None (the usual case), the Residual layer computes the\n",
        "        element-wise sum of the stack-top input with the output of the layer\n",
        "        series. If specified, the `shortcut` layer applies to a copy of the\n",
        "        inputs and (elementwise) adds its output to the output from the main\n",
        "        layer series.\n",
        "\n",
        "  Returns:\n",
        "      A layer representing a residual connection paired with a layer series.\n",
        "  \"\"\"\n",
        "  layers = _ensure_flat(layers)\n",
        "  layer = layers[0] if len(layers) == 1 else Serial(layers)\n",
        "  return Serial(\n",
        "      Branch(shortcut, layer),\n",
        "      Add(),\n",
        "  )\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ruX4aFMdUOwS"
      },
      "source": [
        "Here's a simple code example to highlight the mechanics."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JGGnKjg4ESIg"
      },
      "source": [
        "**Example 6.** `Branch`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 204
        },
        "colab_type": "code",
        "id": "lw6A2YwuW-Ul",
        "outputId": "a07ef350-bafa-4fa7-a083-19e6f725b3ce"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[ -2  -1   0   1   2]\n",
            " [-20 -10   0  10  20]]\n",
            "\n",
            "y0:\n",
            "[[ 0  0  0  1  2]\n",
            " [ 0  0  0 10 20]]\n",
            "\n",
            "y1:\n",
            "[[ -200.  -100.     0.   100.   200.]\n",
            " [-2000. -1000.     0.  1000.  2000.]]\n"
          ]
        }
      ],
      "source": [
        "relu = tl.Relu()\n",
        "times_100 = tl.Fn(\"Times100\", lambda x: x * 100.0)\n",
        "branch_relu_t100 = tl.Branch(relu, times_100)\n",
        "\n",
        "x = np.array([[-2, -1, 0, 1, 2],\n",
        "              [-20, -10, 0, 10, 20]])\n",
        "branch_relu_t100.init(shapes.signature(x))\n",
        "\n",
        "y0, y1 = branch_relu_t100(x)\n",
        "\n",
        "print(f'x:\\n{x}\\n\\n'\n",
        "      f'y0:\\n{y0}\\n\\n'\n",
        "      f'y1:\\n{y1}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zr2ZZ1vO8T8V"
      },
      "source": [
        "## 2. Inputs and Outputs\n",
        "\n",
        "Trax allows layers to have multiple input streams and output streams. When\n",
        "designing a network, you have the flexibility to use layers that:\n",
        "\n",
        "  - process a single data stream ($n_{in} = n_{out} = 1$),\n",
        "  - process multiple parallel data streams ($n_{in} = n_{out} = 2, 3, ... $),\n",
        "  - split or inject data streams ($n_{in} \u003c n_{out}$), or\n",
        "  - merge or remove data streams ($n_{in} \u003e n_{out}$).\n",
        "\n",
        "We saw in section 1 the example of `Residual`, which involves both a split and a merge:\n",
        "```\n",
        "  ...\n",
        "  return Serial(\n",
        "      Branch(shortcut, layer),\n",
        "      Add(),\n",
        "  )\n",
        "```\n",
        "In other words, layer by layer:\n",
        "\n",
        "  - `Branch(shortcut, layer)`: makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers (applied in series). [$n_{in} = 1$, $n_{out} = 2$]\n",
        "  - `Add()`: combines the two streams back into one by adding two tensors elementwise. [$n_{in} = 2$, $n_{out} = 1$]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1FEttSCVVM3T"
      },
      "source": [
        "### Data Stack\n",
        "\n",
        "Trax supports flexible data flows through a network via a data stack, which is\n",
        "managed by the `Serial` combinator:\n",
        "```\n",
        "class Serial(base.Layer):\n",
        "  \"\"\"Combinator that applies layers serially (by function composition).\n",
        "\n",
        "  ...\n",
        "\n",
        "  A Serial combinator uses stack semantics to manage data for its sublayers.\n",
        "  Each sublayer sees only the inputs it needs and returns only the outputs it\n",
        "  has generated. The sublayers interact via the data stack. For instance, a\n",
        "  sublayer k, following sublayer j, gets called with the data stack in the\n",
        "  state left after layer j has applied. The Serial combinator then:\n",
        "\n",
        "    - takes n_in items off the top of the stack (n_in = k.n_in) and calls\n",
        "      layer k, passing those items as arguments; and\n",
        "\n",
        "    - takes layer k's n_out return values (n_out = k.n_out) and pushes\n",
        "      them onto the data stack.\n",
        "\n",
        "  ...\n",
        "\n",
        "  \"\"\"\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5DAiajI-Gzk4"
      },
      "source": [
        "**Simple Case 1 -- Each layer takes one input and has one output.**\n",
        "\n",
        "This is in effect a single data stream pipeline, and the successive layers\n",
        "behave like function composition:\n",
        "\n",
        "```\n",
        "#  s(.) = h(g(f(.)))\n",
        "layer_s = Serial(\n",
        "    layer_f,\n",
        "    layer_g,\n",
        "    layer_h,\n",
        ")\n",
        "```\n",
        "Note how, inside `Serial`, function composition is expressed naturally as a\n",
        "succession of operations, so that no nested parentheses are needed and the\n",
        "order of operations matches the textual order of layers.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WR8bh64tIzIY"
      },
      "source": [
        "**Simple Case 2 -- Each layer consumes all outputs of the preceding layer.**\n",
        "\n",
        "This is still a single pipeline, but data streams internal to it can split and\n",
        "merge. The `Residual` example above illustrates this kind.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ACG88RdtLbvG"
      },
      "source": [
        "**General Case -- Successive layers interact via the data stack.**\n",
        "\n",
        "As described in the `Serial` class docstring, each layer gets its inputs from\n",
        "the data stack after the preceding layer has put its outputs onto the stack.\n",
        "This covers the simple cases above, but also allows for more flexible data\n",
        "interactions between non-adjacent layers. The following example is schematic:\n",
        "```\n",
        "x, y_target = get_batch_of_labeled_data()\n",
        "\n",
        "model_plus_eval = Serial(\n",
        "    my_fancy_deep_model(),  # Takes one arg (x) and has one output (y_hat)\n",
        "    my_eval(),  # Takes two args (y_hat, y_target) and has one output (score)\n",
        ")\n",
        "\n",
        "eval_score = model_plus_eval((x, y_target))\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "66hUOOYRQqej"
      },
      "source": [
        "Here is the corresponding progression of stack states:\n",
        "\n",
        "0. At start: _--empty--_\n",
        "0. After `get_batch_of_labeled_data()`: *x*, *y_target*\n",
        "0. After `my_fancy_deep_model()`: *y_hat*, *y_target*\n",
        "0. After `my_eval()`: *score*\n",
        "\n",
        "Note in particular how the application of the model (between stack states 1\n",
        "and 2) only uses and affects the top element on the stack: `x` --\u003e `y_hat`.\n",
        "The rest of the data stack (`y_target`) comes into play only later, for the\n",
        "eval function."
      ]
    },
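    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The stack progression above can be imitated with a small plain-Python sketch.\n",
        "It only models the idea of layers popping inputs from, and pushing outputs\n",
        "onto, a shared stack; it is not Trax's implementation. The names\n",
        "`run_serial`, `toy_model`, and `toy_eval` are hypothetical, and each toy layer\n",
        "is assumed to be a `(fn, n_in)` pair:\n",
        "\n",
        "```python\n",
        "def run_serial(layers, stack):\n",
        "  # Each layer is a (fn, n_in) pair: it takes n_in items off the top of the\n",
        "  # stack and puts its outputs back on top. Returns all stack states.\n",
        "  stack = list(stack)\n",
        "  states = [tuple(stack)]\n",
        "  for fn, n_in in layers:\n",
        "    args, rest = stack[:n_in], stack[n_in:]\n",
        "    outputs = fn(*args)\n",
        "    if not isinstance(outputs, tuple):\n",
        "      outputs = (outputs,)\n",
        "    stack = list(outputs) + rest\n",
        "    states.append(tuple(stack))\n",
        "  return states\n",
        "\n",
        "\n",
        "# Hypothetical stand-ins: the model doubles x; the eval measures the gap.\n",
        "toy_model = (lambda x: 2 * x, 1)\n",
        "toy_eval = (lambda y_hat, y_target: abs(y_hat - y_target), 2)\n",
        "\n",
        "# Start from the stack as it is after the batch is loaded: (x, y_target).\n",
        "states = run_serial([toy_model, toy_eval], stack=(3, 7))\n",
        "# states[0] is (x, y_target), states[1] is (y_hat, y_target), and\n",
        "# states[2] is (score,), matching stack states 1 through 3 above.\n",
        "```\n",
        "\n",
        "In this sketch the toy model touches only the top stack element, and\n",
        "`y_target` stays in place until the toy eval consumes it."
      ]
    },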
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "65ite-671cTT"
      },
      "source": [
        "## 3. Defining New Layer Classes\n",
        "\n",
        "If you need a layer type that is not easily defined as a combination of\n",
        "existing layer types, you can define your own layer classes in a couple\n",
        "of different ways."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hHSaD9H6hDTf"
      },
      "source": [
        "### With the `Fn` layer-creating function\n",
        "\n",
        "Many layer types needed in deep learning compute pure functions from inputs to\n",
        "outputs, using neither weights nor randomness. You can use Trax's `Fn` function\n",
        "to define your own pure layer types:\n",
        "```\n",
        "def Fn(name, f, n_out=1):  # pylint: disable=invalid-name\n",
        "  \"\"\"Returns a layer with no weights that applies the function `f`.\n",
        "\n",
        "  `f` can take and return any number of arguments, and takes only positional\n",
        "  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).\n",
        "  The following, for example, would create a layer that takes two inputs and\n",
        "  returns two outputs -- element-wise sums and maxima:\n",
        "\n",
        "      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`\n",
        "\n",
        "  The layer's number of inputs (`n_in`) is automatically set to number of\n",
        "  positional arguments in `f`, but you must explicitly set the number of\n",
        "  outputs (`n_out`) whenever it's not the default value 1.\n",
        "\n",
        "  Args:\n",
        "    name: Class-like name for the resulting layer; for use in debugging.\n",
        "    f: Pure function from input tensors to output tensors, where each input\n",
        "        tensor is a separate positional arg, e.g., `f(x0, x1) --\u003e x0 + x1`.\n",
        "        Output tensors must be packaged as specified in the `Layer` class\n",
        "        docstring.\n",
        "    n_out: Number of outputs promised by the layer; default value 1.\n",
        "\n",
        "  Returns:\n",
        "    Layer executing the function `f`.\n",
        "  \"\"\"\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TX30lGLXcjB1"
      },
      "source": [
        "**Example 7.** Use `Fn` to define a new layer type:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 153
        },
        "colab_type": "code",
        "id": "vKrc6XMV9ErS",
        "outputId": "13f74094-e43e-4267-9055-f3d55d58ae53"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x0:\n",
            "[ 1  2  3  4  5  6  7  8  9 10]\n",
            "\n",
            "x1:\n",
            "[11 12 13 14 15 16 17 18 19 20]\n",
            "\n",
            "gcd((x0, x1)):\n",
            "[ 1  2  1  2  5  2  1  2  1 10]\n"
          ]
        }
      ],
      "source": [
        "# Define new layer type.\n",
        "def Gcd():\n",
        "  \"\"\"Returns a layer to compute the greatest common divisor, elementwise.\"\"\"\n",
        "  return tl.Fn('Gcd', lambda x0, x1: jnp.gcd(x0, x1))\n",
        "\n",
        "# Use it.\n",
        "gcd = Gcd()\n",
        "\n",
        "x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n",
        "x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])\n",
        "\n",
        "y = gcd((x0, x1))\n",
        "\n",
        "print(f'x0:\\n{x0}\\n\\n'\n",
        "      f'x1:\\n{x1}\\n\\n'\n",
        "      f'gcd((x0, x1)):\\n{y}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "W74Eehgp5A57"
      },
      "source": [
        "The `Fn` function infers `n_in` (number of inputs) as the length of `f`'s arg\n",
        "list. `Fn` does not infer `n_out` (number of outputs) though. If your `f` has\n",
        "more than one output, you need to give an explicit value using the `n_out`\n",
        "keyword arg."
      ]
    },
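    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see why the number of inputs can be inferred while the number of outputs\n",
        "cannot, here is a minimal sketch using Python's standard `inspect` module.\n",
        "The helper `count_positional_args` is a hypothetical name for this\n",
        "illustration, and the sketch is not claimed to be how `Fn` does it\n",
        "internally:\n",
        "\n",
        "```python\n",
        "import inspect\n",
        "\n",
        "\n",
        "def count_positional_args(f):\n",
        "  # Counts the parameters of `f` that can be passed positionally.\n",
        "  positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,\n",
        "                      inspect.Parameter.POSITIONAL_OR_KEYWORD)\n",
        "  params = inspect.signature(f).parameters.values()\n",
        "  return sum(1 for p in params if p.kind in positional_kinds)\n",
        "\n",
        "\n",
        "n_in_gcd = count_positional_args(lambda x0, x1: x0)  # Two inputs.\n",
        "n_in_single = count_positional_args(lambda x: x)  # One input.\n",
        "```\n",
        "\n",
        "A signature lists a function's parameters but says nothing about how many\n",
        "values it returns, so the output count has to be stated separately."
      ]
    },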
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2lCjml7SCR-u"
      },
      "source": [
        "**Example 8.** `Fn` with multiple outputs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 204
        },
        "colab_type": "code",
        "id": "rfnA2B9ZczWK",
        "outputId": "9ffd7648-ffda-453e-b88b-4aa4ba8ea482"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x0:\n",
            "[1 2 3 4 5]\n",
            "\n",
            "x1:\n",
            "[ 10 -20  30 -40  50]\n",
            "\n",
            "y0:\n",
            "[ 11 -18  33 -36  55]\n",
            "\n",
            "y1:\n",
            "[10  2 30  4 50]\n"
          ]
        }
      ],
      "source": [
        "# Define new layer type.\n",
        "def SumAndMax():\n",
        "  \"\"\"Returns a layer to compute sums and maxima of two input tensors.\"\"\"\n",
        "  return tl.Fn('SumAndMax',\n",
        "               lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),\n",
        "               n_out=2)\n",
        "\n",
        "# Use it.\n",
        "sum_and_max = SumAndMax()\n",
        "\n",
        "x0 = np.array([1, 2, 3, 4, 5])\n",
        "x1 = np.array([10, -20, 30, -40, 50])\n",
        "\n",
        "y0, y1 = sum_and_max([x0, x1])\n",
        "\n",
        "print(f'x0:\\n{x0}\\n\\n'\n",
        "      f'x1:\\n{x1}\\n\\n'\n",
        "      f'y0:\\n{y0}\\n\\n'\n",
        "      f'y1:\\n{y1}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GrXQUSbKDs41"
      },
      "source": [
        "**Example 9.** Use `Fn` to define a configurable layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 374
        },
        "colab_type": "code",
        "id": "h1KwpmFpEIK3",
        "outputId": "9f6e7009-04a0-46c9-b005-35c091f720eb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x:\n",
            "[[[  1   2   3]\n",
            "  [ 10  20  30]\n",
            "  [100 200 300]]\n",
            "\n",
            " [[  4   5   6]\n",
            "  [ 40  50  60]\n",
            "  [400 500 600]]]\n",
            "\n",
            "flatten_keep_1_axis(x):\n",
            "[[  1   2   3  10  20  30 100 200 300]\n",
            " [  4   5   6  40  50  60 400 500 600]]\n",
            "\n",
            "flatten_keep_2_axes(x):\n",
            "[[[  1   2   3]\n",
            "  [ 10  20  30]\n",
            "  [100 200 300]]\n",
            "\n",
            " [[  4   5   6]\n",
            "  [ 40  50  60]\n",
            "  [400 500 600]]]\n"
          ]
        }
      ],
      "source": [
        "# Function defined in trax/layers/core.py:\n",
        "def Flatten(n_axes_to_keep=1):\n",
        "  \"\"\"Returns a layer that combines one or more trailing axes of a tensor.\n",
        "\n",
        "  Flattening keeps all the values of the input tensor, but reshapes it by\n",
        "  collapsing one or more trailing axes into a single axis. For example, a\n",
        "  `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape\n",
        "  `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.\n",
        "\n",
        "  Args:\n",
        "    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;\n",
        "        collapse only the axes after these.\n",
        "  \"\"\"\n",
        "  layer_name = f'Flatten_keep{n_axes_to_keep}'\n",
        "  def f(x):\n",
        "    in_rank = len(x.shape)\n",
        "    if in_rank \u003c= n_axes_to_keep:\n",
        "      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '\n",
        "                       f'axes to keep ({n_axes_to_keep}) after flattening.')\n",
        "    return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))\n",
        "  return tl.Fn(layer_name, f)\n",
        "\n",
        "flatten_keep_1_axis = Flatten(n_axes_to_keep=1)\n",
        "flatten_keep_2_axes = Flatten(n_axes_to_keep=2)\n",
        "\n",
        "x = np.array([[[1, 2, 3],\n",
        "               [10, 20, 30],\n",
        "               [100, 200, 300]],\n",
        "              [[4, 5, 6],\n",
        "               [40, 50, 60],\n",
        "               [400, 500, 600]]])\n",
        "\n",
        "y1 = flatten_keep_1_axis(x)\n",
        "y2 = flatten_keep_2_axes(x)\n",
        "\n",
        "print(f'x:\\n{x}\\n\\n'\n",
        "      f'flatten_keep_1_axis(x):\\n{y1}\\n\\n'\n",
        "      f'flatten_keep_2_axes(x):\\n{y2}')\n",
        "\n"
      ]
    },
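    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The shape arithmetic behind `Flatten` can be checked without any tensors. The\n",
        "following plain-Python sketch computes only the output shape; the helper name\n",
        "`flattened_shape` is hypothetical and not part of Trax:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def flattened_shape(shape, n_axes_to_keep=1):\n",
        "  # Keeps the leading axes and collapses the trailing ones into a single axis.\n",
        "  if len(shape) <= n_axes_to_keep:\n",
        "    raise ValueError(f'Input rank ({len(shape)}) must exceed the number of '\n",
        "                     f'axes to keep ({n_axes_to_keep}) after flattening.')\n",
        "  kept = tuple(shape[:n_axes_to_keep])\n",
        "  return kept + (math.prod(shape[n_axes_to_keep:]),)\n",
        "\n",
        "\n",
        "docstring_case = flattened_shape((2, 3, 5, 7, 11), n_axes_to_keep=2)\n",
        "keep_1 = flattened_shape((2, 3, 3), n_axes_to_keep=1)\n",
        "keep_2 = flattened_shape((2, 3, 3), n_axes_to_keep=2)\n",
        "```\n",
        "\n",
        "Here `docstring_case` is `(2, 3, 385)`, `keep_1` is `(2, 9)`, and `keep_2` is\n",
        "`(2, 3, 3)`. The last result shows why `flatten_keep_2_axes(x)` above prints\n",
        "the same array as `x`: only one trailing axis remains to collapse."
      ]
    },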
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cqM6WJwNhoHI"
      },
      "source": [
        "### By defining a `Layer` subclass\n",
        "\n",
        "If you need a layer type that uses trainable weights (or state), you can extend\n",
        "the base `Layer` class:\n",
        "```\n",
        "class Layer:\n",
        "  \"\"\"Base class for composable layers in a deep learning network.\n",
        "\n",
        "  ...\n",
        "\n",
        "  Authors of new layer subclasses typically override at most two methods of\n",
        "  the base `Layer` class:\n",
        "\n",
        "    `forward(inputs)`:\n",
        "      Computes this layer's output as part of a forward pass through the model.\n",
        "\n",
        "    `init_weights_and_state(self, input_signature)`:\n",
        "      Initializes weights and state for inputs with the given signature.\n",
        "\n",
        "  ...\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tZlzxNUigD_4"
      },
      "source": [
        "The `forward` method uses *weights stored in the layer object* (`self.weights`)\n",
        "to compute outputs from inputs. For example, here is the definition of\n",
        "`forward` for Trax's `Dense` layer:\n",
        "```\n",
        "  def forward(self, x):\n",
        "    \"\"\"Executes this layer as part of a forward pass through the model.\n",
        "\n",
        "    Args:\n",
        "      x: Tensor of same shape and dtype as the input signature used to\n",
        "          initialize this layer.\n",
        "\n",
        "    Returns:\n",
        "      Tensor of same shape and dtype as the input, except the final dimension\n",
        "      is the layer's `n_units` value.\n",
        "    \"\"\"\n",
        "    if self._use_bias:\n",
        "      if not isinstance(self.weights, (tuple, list)):\n",
        "        raise ValueError(f'Weights should be a (w, b) tuple or list; '\n",
        "                         f'instead got: {self.weights}')\n",
        "      w, b = self.weights\n",
        "      return jnp.dot(x, w) + b  # Affine map.\n",
        "    else:\n",
        "      w = self.weights\n",
        "      return jnp.dot(x, w)  # Linear map.\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PJEEyX9_iPbk"
      },
      "source": [
        "Layer weights must be initialized before the layer can be used; the\n",
        "`init_weights_and_state` method specifies how. Continuing the `Dense` example,\n",
        "here is the corresponding initialization code:\n",
        "```\n",
        "  def init_weights_and_state(self, input_signature):\n",
        "    \"\"\"Randomly initializes this layer's weights.\n",
        "\n",
        "    Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the\n",
        "    default case), or a `w` tensor for layers created with `use_bias=False`.\n",
        "\n",
        "    Args:\n",
        "      input_signature: `ShapeDtype` instance characterizing the input this layer\n",
        "          should compute on.\n",
        "    \"\"\"\n",
        "    shape_w = (input_signature.shape[-1], self._n_units)\n",
        "    shape_b = (self._n_units,)\n",
        "    rng_w, rng_b = fastmath.random.split(self.rng, 2)\n",
        "    w = self._kernel_initializer(shape_w, rng_w)\n",
        "\n",
        "    if self._use_bias:\n",
        "      b = self._bias_initializer(shape_b, rng_b)\n",
        "      self.weights = (w, b)\n",
        "    else:\n",
        "      self.weights = w\n",
        "\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "D77mYZZD41QO"
      },
      "source": [
        "### By defining a `Combinator` subclass\n",
        "\n",
        "*TBD*"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PgdQvZ5G6Aei"
      },
      "source": [
        "## 4. Testing and Debugging Layer Classes\n",
        "\n",
        "*TBD*"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "Trax Layers Intro",
      "provenance": [
        {
          "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt",
          "timestamp": 1569980697572
        },
        {
          "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5",
          "timestamp": 1563927451951
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
